from __future__ import annotations

import asyncio
import copy
import sys
import time
from collections import deque
from typing import Any, Deque, Dict, List, Tuple

import fastapi

from gradio.data_classes import Estimation, PredictBody, Progress, ProgressUnit
from gradio.helpers import TrackedIterable
from gradio.utils import AsyncRequest, run_coro_in_background, set_task_name


class Event:
    """One client request waiting in (or being processed by) the Queue.

    An Event owns the client's websocket; its payload (``data``) and
    progress snapshot are filled in later by the Queue.
    """

    def __init__(
        self,
        websocket: fastapi.WebSocket,
        session_hash: str,
        fn_index: int,
    ):
        # Identity of the request.
        self.session_hash: str = session_hash
        self.fn_index: int = fn_index
        self._id = "{}_{}".format(session_hash, fn_index)
        # Connection to the client.
        self.websocket = websocket
        self.lost_connection_time: float | None = None
        self.token: str | None = None
        # Request payload; stays None until it has been received.
        self.data: PredictBody | None = None
        # Latest progress snapshot and whether it still needs to be sent.
        self.progress: Progress | None = None
        self.progress_pending: bool = False

    async def disconnect(self, code: int = 1000):
        """Close the client's websocket with the given close code."""
        websocket = self.websocket
        await websocket.close(code=code)


class Queue:
    """Queue of websocket-connected Events processed by a fixed pool of worker slots.

    Events are pushed onto ``event_queue``. ``start_processing`` pulls them off,
    batching them when the function's dependency allows it. The payloads are
    posted to the server's ``api/predict`` endpoint, and results are sent back
    over each event's websocket. ETA estimates go to waiting clients after each
    dispatch (``live_updates``) or every ``update_intervals`` seconds otherwise.
    """

    def __init__(
        self,
        live_updates: bool,
        concurrency_count: int,
        update_intervals: float,
        max_size: int | None,
        blocks_dependencies: List,
    ):
        self.event_queue: Deque[Event] = deque()
        self.events_pending_reconnection = []
        self.stopped = False
        self.max_thread_count = concurrency_count
        self.update_intervals = update_intervals
        # One slot per worker: None means free, otherwise the events being processed.
        self.active_jobs: List[None | List[Event]] = [None] * concurrency_count
        # Serializes removals from event_queue (batch extraction and cleanup).
        self.delete_lock = asyncio.Lock()
        # Base URL of the server; "api/predict" and "reset" are appended to it.
        self.server_path = None
        # Running totals used by update_estimation to compute averages.
        self.duration_history_total = 0
        self.duration_history_count = 0
        self.avg_process_time = 0
        self.avg_concurrent_process_time = None
        self.queue_duration = 1
        self.live_updates = live_updates
        # Polling intervals (seconds) used while there is nothing to do.
        self.sleep_when_free = 0.05
        self.progress_update_sleep_when_free = 0.1
        # None means the queue is unbounded (see push).
        self.max_size = max_size
        # Indexed by fn_index; entries are read for "batch" and, when batching,
        # "max_batch_size".
        self.blocks_dependencies = blocks_dependencies
        # Sent as a Bearer token on every api/predict call.
        self.access_token = ""

    async def start(self, progress_tracking=False):
        """Launch the background loops that drive the queue.

        Parameters:
            progress_tracking: also start the loop that pushes progress updates.
        """
        run_coro_in_background(self.start_processing)
        if progress_tracking:
            run_coro_in_background(self.start_progress_tracking)
        if not self.live_updates:
            # Without live updates, estimations are broadcast periodically instead.
            run_coro_in_background(self.notify_clients)

    def close(self):
        """Signal the background loops to stop."""
        self.stopped = True

    def resume(self):
        """Clear the stop flag set by close()."""
        self.stopped = False

    def set_url(self, url: str):
        """Set the server base URL used for prediction and reset requests."""
        self.server_path = url

    def set_access_token(self, token: str):
        """Set the bearer token sent with prediction requests."""
        self.access_token = token

    def get_active_worker_count(self) -> int:
        """Return the number of worker slots currently running a job."""
        return sum(1 for worker in self.active_jobs if worker is not None)

    def get_events_in_batch(self) -> Tuple[List[Event] | None, bool]:
        """Pop the next job off the queue.

        Returns:
            (events, batch). ``events`` is None when the queue is empty. When the
            first event's dependency enables batching, up to ``max_batch_size - 1``
            further queued events with the same fn_index are pulled out with it,
            in queue order.
        """
        if not self.event_queue:
            return None, False

        first_event = self.event_queue.popleft()
        events = [first_event]

        event_fn_index = first_event.fn_index
        dependency = self.blocks_dependencies[event_fn_index]
        batch = dependency["batch"]

        if batch:
            batch_size = dependency["max_batch_size"]
            rest_of_batch = [
                event for event in self.event_queue if event.fn_index == event_fn_index
            ][: batch_size - 1]
            events.extend(rest_of_batch)
            if rest_of_batch:
                # Remove the batched events in one pass (deque.remove per event
                # would be O(n) each). Mutate in place so outside references to
                # event_queue stay valid.
                taken = {id(event) for event in rest_of_batch}
                remaining = [
                    event for event in self.event_queue if id(event) not in taken
                ]
                self.event_queue.clear()
                self.event_queue.extend(remaining)

        return events, batch

    async def start_processing(self) -> None:
        """Dispatch queued jobs to free worker slots until the queue is stopped."""
        while not self.stopped:
            if not self.event_queue or None not in self.active_jobs:
                # Nothing to run, or every worker slot is busy.
                await asyncio.sleep(self.sleep_when_free)
                continue
            # Using mutex to avoid editing a list in use
            async with self.delete_lock:
                events, batch = self.get_events_in_batch()

            if events:
                self.active_jobs[self.active_jobs.index(None)] = events
                task = run_coro_in_background(self.process_events, events, batch)
                run_coro_in_background(self.broadcast_live_estimations)
                set_task_name(task, events[0].session_hash, events[0].fn_index, batch)

    async def start_progress_tracking(self) -> None:
        """Periodically send pending progress snapshots of running events to clients."""
        while not self.stopped:
            if not any(self.active_jobs):
                await asyncio.sleep(self.progress_update_sleep_when_free)
                continue

            for job in self.active_jobs:
                if job is None:
                    continue
                for event in job:
                    if event.progress_pending and event.progress:
                        event.progress_pending = False
                        client_awake = await self.send_message(
                            event, event.progress.dict()
                        )
                        if not client_awake:
                            await self.clean_event(event)

            await asyncio.sleep(self.progress_update_sleep_when_free)

    def set_progress(
        self,
        event_id: str,
        iterables: List[TrackedIterable] | None,
    ):
        """Attach a progress snapshot to every running event whose _id is event_id.

        The snapshot is flagged as pending, so start_progress_tracking sends it
        on its next pass. Does nothing when ``iterables`` is None.
        """
        if iterables is None:
            return
        for job in self.active_jobs:
            if job is None:
                continue
            for evt in job:
                if evt._id == event_id:
                    progress_data: List[ProgressUnit] = [
                        ProgressUnit(
                            index=iterable.index,
                            length=iterable.length,
                            unit=iterable.unit,
                            progress=iterable.progress,
                            desc=iterable.desc,
                        )
                        for iterable in iterables
                    ]
                    evt.progress = Progress(progress_data=progress_data)
                    evt.progress_pending = True

    def push(self, event: Event) -> int | None:
        """
        Add event to queue, or return None if Queue is full
        Parameters:
            event: Event to add to Queue
        Returns:
            rank of submitted Event (its 0-based position in the queue)
        """
        queue_len = len(self.event_queue)
        if self.max_size is not None and queue_len >= self.max_size:
            return None
        self.event_queue.append(event)
        return queue_len

    async def clean_event(self, event: Event) -> None:
        """Remove ``event`` from the waiting queue if it is still there."""
        if event not in self.event_queue:
            return
        async with self.delete_lock:
            # Re-check under the lock: another task may have removed the event
            # while this one was waiting, and deque.remove would then raise.
            if event in self.event_queue:
                self.event_queue.remove(event)

    async def broadcast_live_estimations(self) -> None:
        """
        Broadcast ETA estimations to all waiting clients, but only when live
        updates are enabled (otherwise notify_clients does it periodically).
        """
        if self.live_updates:
            await self.broadcast_estimations()

    async def gather_event_data(self, event: Event) -> bool:
        """
        Ensure the event has its request payload, asking the client for it if needed.

        Parameters:
            event: the event whose ``data`` should be populated.
        Returns:
            True if the event now has data. False if the client could not be
            reached or no valid payload was received. Such an event must not be
            processed: batch outputs are matched to events by position.
        """
        if not event.data:
            client_awake = await self.send_message(event, {"msg": "send_data"})
            if not client_awake:
                return False
            event.data = await self.get_message(event)
        return event.data is not None

    async def notify_clients(self) -> None:
        """
        Notify clients about events statuses in the queue periodically.
        """
        while not self.stopped:
            await asyncio.sleep(self.update_intervals)
            if self.event_queue:
                await self.broadcast_estimations()

    async def broadcast_estimations(self) -> None:
        """Send each waiting client its rank and the current ETA estimation."""
        estimation = self.get_estimation()
        # Send all messages concurrently
        await asyncio.gather(
            *[
                self.send_estimation(event, estimation, rank)
                for rank, event in enumerate(self.event_queue)
            ]
        )

    async def send_estimation(
        self, event: Event, estimation: Estimation, rank: int
    ) -> Estimation:
        """
        Send estimation about ETA to the client.

        Parameters:
            event: recipient event.
            estimation: shared estimation object; its rank fields are overwritten.
            rank: 0-based position of the event in the queue.
        Returns:
            the (mutated) estimation.
        """
        estimation.rank = rank

        if self.avg_concurrent_process_time is not None:
            estimation.rank_eta = (
                estimation.rank * self.avg_concurrent_process_time
                + self.avg_process_time
            )
            if None not in self.active_jobs:
                # Add estimated amount of time for a thread to get empty
                estimation.rank_eta += self.avg_concurrent_process_time
        # .dict() is evaluated before the first await, so concurrent callers
        # sharing this estimation each serialize their own rank.
        client_awake = await self.send_message(event, estimation.dict())
        if not client_awake:
            await self.clean_event(event)
        return estimation

    def update_estimation(self, duration: float) -> None:
        """
        Record one successful job duration and refresh the running averages.

        Parameters:
            duration: wall-clock seconds the job took.
        """
        self.duration_history_total += duration
        self.duration_history_count += 1
        self.avg_process_time = (
            self.duration_history_total / self.duration_history_count
        )
        # Until max_thread_count jobs have finished, divide by the number seen.
        self.avg_concurrent_process_time = self.avg_process_time / min(
            self.max_thread_count, self.duration_history_count
        )
        self.queue_duration = self.avg_concurrent_process_time * len(self.event_queue)

    def get_estimation(self) -> Estimation:
        """Build an Estimation snapshot of the current queue state."""
        return Estimation(
            queue_size=len(self.event_queue),
            avg_event_process_time=self.avg_process_time,
            avg_event_concurrent_process_time=self.avg_concurrent_process_time,
            queue_eta=self.queue_duration,
        )

    def get_request_params(self, websocket: fastapi.WebSocket) -> Dict[str, Any]:
        """Extract URL, headers, query/path params and client address of a websocket."""
        return {
            "url": str(websocket.url),
            "headers": dict(websocket.headers),
            "query_params": dict(websocket.query_params),
            "path_params": dict(websocket.path_params),
            "client": dict(host=websocket.client.host, port=websocket.client.port),  # type: ignore
        }

    async def call_prediction(self, events: List[Event], batch: bool):
        """POST the events' payload to the server's api/predict endpoint.

        In batch mode the per-event inputs are zipped into one request.
        NOTE: the first event's ``data`` object is the one sent and is
        modified in place.

        Returns:
            the AsyncRequest response.
        """
        data = events[0].data
        assert data is not None, "No event data"
        token = events[0].token
        data.event_id = events[0]._id if not batch else None
        try:
            data.request = self.get_request_params(events[0].websocket)
        except ValueError:
            pass

        if batch:
            data.data = list(zip(*[event.data.data for event in events if event.data]))
            data.request = [
                self.get_request_params(event.websocket)
                for event in events
                if event.data
            ]
            data.batched = True

        response = await AsyncRequest(
            method=AsyncRequest.Method.POST,
            url=f"{self.server_path}api/predict",
            json=dict(data),
            headers={"Authorization": f"Bearer {self.access_token}"},
            cookies={"access-token": token} if token is not None else None,
        )
        return response

    async def process_events(self, events: List[Event], batch: bool) -> None:
        """Run one job (a single event or a batch) and send its results to clients.

        Gathers each event's payload and drops clients that are unreachable. Then
        calls the prediction endpoint, repeatedly while it reports
        ``is_generating``, and sends the outputs. Afterwards it always closes
        the websockets, frees the worker slot and resets iterators.
        """
        awake_events: List[Event] = []
        try:
            for event in events:
                client_awake = await self.gather_event_data(event)
                if client_awake:
                    client_awake = await self.send_message(
                        event, {"msg": "process_starts"}
                    )
                if client_awake:
                    awake_events.append(event)
            if not awake_events:
                return
            begin_time = time.time()
            response = await self.call_prediction(awake_events, batch)
            if response.has_exception:
                for event in awake_events:
                    await self.send_message(
                        event,
                        {
                            "msg": "process_completed",
                            "output": {"error": str(response.exception)},
                            "success": False,
                        },
                    )
            elif response.json.get("is_generating", False):
                old_response = response
                while response.json.get("is_generating", False):
                    # Python 3.7 doesn't have named tasks.
                    # In order to determine if a task was cancelled, we
                    # ping the websocket to see if it was closed mid-iteration.
                    # (awake_events is never empty here; see the returns below.)
                    if sys.version_info < (3, 8):
                        is_alive = await self.send_message(
                            awake_events[0], {"msg": "alive?"}
                        )
                        if not is_alive:
                            return
                    old_response = response
                    open_ws = []
                    for event in awake_events:
                        ws_open = await self.send_message(
                            event,
                            {
                                "msg": "process_generating",
                                "output": old_response.json,
                                "success": old_response.status == 200,
                            },
                        )
                        open_ws.append(ws_open)
                    awake_events = [
                        e for e, is_open in zip(awake_events, open_ws) if is_open
                    ]
                    if not awake_events:
                        return
                    response = await self.call_prediction(awake_events, batch)
                # A failing final call is reported as-is; otherwise the last
                # generated output is the final result.
                relevant_response = response if response.status != 200 else old_response
                for event in awake_events:
                    await self.send_message(
                        event,
                        {
                            "msg": "process_completed",
                            "output": relevant_response.json,
                            "success": relevant_response.status == 200,
                        },
                    )
            else:
                output = copy.deepcopy(response.json)
                # For batches, transpose the outputs once; row e belongs to the
                # e-th awake event (the same order call_prediction used).
                batched_data = (
                    list(zip(*response.json.get("data")))
                    if batch and "data" in output
                    else None
                )
                for e, event in enumerate(awake_events):
                    if batched_data is not None:
                        output["data"] = batched_data[e]
                    await self.send_message(
                        event,
                        {
                            "msg": "process_completed",
                            "output": output,
                            "success": response.status == 200,
                        },
                    )
            end_time = time.time()
            if response.status == 200:
                self.update_estimation(end_time - begin_time)
        finally:
            # Close every socket in the job, not only the awake ones, so that
            # clients dropped for a missing payload are not left hanging.
            # Closing an already-dead socket raises; that is ignored.
            for event in events:
                try:
                    await event.disconnect()
                except Exception:
                    pass
            self.active_jobs[self.active_jobs.index(events)] = None
            for event in awake_events:
                await self.clean_event(event)
                # Always reset the state of the iterator
                # If the job finished successfully, this has no effect
                # If the job is cancelled, this will enable future runs
                # to start "from scratch"
                await self.reset_iterators(event.session_hash, event.fn_index)

    async def send_message(self, event, data: Dict) -> bool:
        """Send ``data`` as JSON over the event's websocket.

        Returns:
            True on success. On failure the event is removed from the queue and
            False is returned. Cancellation and interpreter-exit exceptions are
            not swallowed.
        """
        try:
            await event.websocket.send_json(data=data)
            return True
        except Exception:
            await self.clean_event(event)
            return False

    async def get_message(self, event) -> PredictBody | None:
        """Receive one JSON message from the event's websocket as a PredictBody.

        Returns:
            the parsed body, or None if receiving or parsing failed (the event is
            then removed from the queue).
        """
        try:
            data = await event.websocket.receive_json()
            return PredictBody(**data)
        except Exception:
            await self.clean_event(event)
            return None

    async def reset_iterators(self, session_hash: str, fn_index: int):
        """Ask the server to reset the iterator state for (session_hash, fn_index)."""
        await AsyncRequest(
            method=AsyncRequest.Method.POST,
            url=f"{self.server_path}reset",
            json={
                "session_hash": session_hash,
                "fn_index": fn_index,
            },
        )
